# -*- coding: utf-8 -*-
import os
import numpy as np
from scikits.audiolab import wavread
import scikits.talkbox
import fastdtw
from scikits.talkbox import features
from scikits.talkbox.features import mfcc

# Per-coefficient mask used by mydist(): an entry > 0 means that MFCC
# coefficient takes part in the frame distance, 0 means it is ignored.
# One entry per coefficient (13); coefficient 0 is excluded by default.
# train() overwrites these entries while it searches for a better mask.
weight = [0] + [1] * 12

# Feature cache used by getfeature(): WAV path string -> MFCC matrix.
fdict = dict()

def getfeature(x):
    """Return the MFCC feature matrix for the WAV file at path *x*.

    Results are memoized in the module-level ``fdict``, keyed by the exact
    path string. Differently spelled paths to the same file are therefore
    computed and cached separately.

    The file is decoded with ``wavread`` and passed to talkbox's ``mfcc``.
    Only element 0 of ``mfcc``'s result is kept (presumably the cepstral
    coefficient matrix; TODO confirm against the scikits.talkbox docs).
    """
    # ``in`` instead of dict.has_key(): has_key does not exist on Python 3.
    if x in fdict:
        return fdict[x]
    data, fs, _enc = wavread(x)
    # Forward the file's real sampling rate. talkbox's mfcc() otherwise
    # falls back to its default fs (16000), which misplaces the mel
    # filterbank for recordings at any other rate.
    temp = mfcc(data, fs=fs)[0]
    fdict[x] = temp
    return temp
    
def mydist(a, b):
    """Return the mean absolute difference between two MFCC frames.

    Only coefficients whose entry in the module-level ``weight`` list is
    > 0 are considered. ``weight`` acts as an on/off mask, not as a
    multiplicative weighting. Index ``i`` of ``weight`` selects ``a[i]``
    and ``b[i]``.

    Raises:
        ValueError: if ``weight`` masks out every coefficient.
    """
    total = 0
    count = 0
    for i, w in enumerate(weight):
        if w > 0:
            total += abs(a[i] - b[i])
            count += 1
    if count == 0:
        raise ValueError("mydist: every coefficient is masked out by weight")
    # float() keeps this a true division under Python 2 for integer inputs.
    return total / float(count)
    
def dist(mf1, mf2):
    """Return the FastDTW distance between two MFCC frame sequences.

    Frames are compared with mydist (the weight-masked mean absolute
    difference), using a search radius of 1. The warping path that
    fastdtw also returns is discarded.
    """
    distance, _path = fastdtw.fastdtw(mf1, mf2, 1, mydist)
    return distance
        
def calcsimilar(x, y):
    """Return the DTW distance between the WAV files at paths *x* and *y*.

    Both files go through getfeature() (cached per path) first, *x* before
    *y*. The returned value is a distance, so smaller means closer.
    """
    return dist(getfeature(x), getfeature(y))
    
# Recursively collect the .wav files found under a directory tree.
# rootdir:  directory to scan
# filelist: list that matching file paths are appended to (modified in place)
def GetFileFromThisRootDir(rootdir, filelist, ext='.wav'):
    """Recursively collect files under *rootdir* whose names end with *ext*.

    Matching paths are appended to *filelist* in place, depth-first and in
    os.listdir() order. The same list is also returned for convenience.

    Args:
        rootdir: directory to scan.
        filelist: list that receives the matching paths (mutated in place).
        ext: suffix, or tuple of suffixes, including the leading dot. It is
            matched case-sensitively. Defaults to '.wav', the only suffix
            the original version supported.

    Returns:
        *filelist*.
    """
    for entry in os.listdir(rootdir):
        filepath = os.path.join(rootdir, entry)
        # NOTE(review): isdir() follows symlinks, so a symlink cycle would
        # recurse until RecursionError.
        if os.path.isdir(filepath):
            GetFileFromThisRootDir(filepath, filelist, ext)
        elif entry.endswith(ext) and os.path.exists(filepath):
            # exists() skips entries listdir reports but that do not resolve
            # to a real file (e.g. broken symlinks).
            filelist.append(filepath)
    return filelist


def calcfactor(path, name, template):
    """Score how well the current ``weight`` mask separates one word class.

    Every .wav file under *path* is compared with *template* via dist().
    Files whose path contains ``name + ".wav"`` count as the same class;
    all other files count as a different class. The score is::

        (min_other - max_same) / max_same

    This is the relative margin between the closest other-class file and
    the farthest same-class file. Larger is better. A negative score means
    some other-class file is closer to the template than some same-class
    file.

    Returns:
        The score as a float. If ``max_same`` is 0 (every same-class file
        matches the template exactly), returns ``inf`` when ``min_other``
        is > 0 and 0.0 otherwise.

    Raises:
        ValueError: if no same-class file or no other-class file is found.
    """
    ftemplate = getfeature(template)
    files = []
    GetFileFromThisRootDir(path, files)
    target = name + ".wav"
    same_class = []
    other_class = []
    for filename in files:
        d = dist(getfeature(filename), ftemplate)
        # NOTE(review): substring match on the full path, so for name "p"
        # a file such as "sleep.wav" also counts as same-class. Confirm this
        # is the intended naming convention.
        if target in filename:
            same_class.append(d)
        else:
            other_class.append(d)
    if not same_class or not other_class:
        raise ValueError(
            "calcfactor: need at least one file matching %r and one other "
            ".wav file under %r" % (target, path))
    max_dist = max(same_class)
    min_dist = min(other_class)
    if max_dist == 0:
        # The margin is unbounded when other-class files are farther away,
        # and zero when they are also at distance 0.
        return float('inf') if min_dist > 0 else 0.0
    # float() keeps this a true division under Python 2.
    return (min_dist - max_dist) / float(max_dist)
                
def train(path, name, template):
    """Find the MFCC coefficients whose removal best separates *name*.

    Tries every mask that zeroes out between one and four coefficients
    (all other ``weight`` entries set to 1). Each mask is scored with
    calcfactor(), the score is printed, and the best mask is returned.

    Returns:
        ``(best_score, i, j, m, n)``: the highest calcfactor() score and the
        coefficient indices masked to obtain it. Indices repeat when fewer
        than four distinct coefficients are masked. All indices are -1 if
        no score was comparable (e.g. all NaN).

    The score depends only on the *set* of masked indices. Each set is
    therefore evaluated once, in the lexicographic order of its first
    (n, m, i, j) tuple. This returns the same result as an exhaustive loop
    over all 13**4 ordered tuples, with about 26x fewer calcfactor() calls.

    The module-level ``weight`` list is restored to its previous contents
    on return, including when calcfactor() raises.
    """
    import itertools

    ncoef = len(weight)
    saved_weight = list(weight)
    best_score = float('-inf')
    best_combo = (-1, -1, -1, -1)
    seen = set()
    try:
        # Non-decreasing (n, m, i, j) tuples in lexicographic order. The
        # first tuple yielded for a given index set is also the first one a
        # full n/m/i/j nested loop would reach for that set.
        for combo in itertools.combinations_with_replacement(range(ncoef), 4):
            masked = frozenset(combo)
            if masked in seen:
                continue
            seen.add(masked)
            for k in range(ncoef):
                weight[k] = 0 if k in masked else 1
            d = calcfactor(path, name, template)
            print(d)
            if d > best_score:
                best_score = d
                best_combo = combo
    finally:
        weight[:] = saved_weight
    n, m, i, j = best_combo
    return best_score, i, j, m, n
  
  
if __name__ == '__main__':
    # Usage: python <script> [sample_dir word template_wav]
    # With no arguments, the original hard-coded demo paths are used.
    import sys

    if len(sys.argv) == 4:
        base_dir, word, template_wav = sys.argv[1:4]
    elif len(sys.argv) == 1:
        base_dir = "D:\\GTT\\Demo\\DeepLearn\\py\\UI\\base_p\\"
        word = "p"
        template_wav = base_dir + "p.wav"
    else:
        sys.exit("usage: %s [sample_dir word template_wav]" % sys.argv[0])
    # calcfactor() extracts (and caches) the template's features itself,
    # so no separate warm-up call is needed.
    print(calcfactor(base_dir, word, template_wav))
    print(train(base_dir, word, template_wav))
    